-
Notifications
You must be signed in to change notification settings - Fork 15.2k
(WIP) [MLIR] Add a new interface for "IR parameterization" #78544
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-ods Author: Mehdi Amini (joker-eph) ChangesThis implements the ability to define "meta program": that is a mechanism similar to C++ template. So as an example, this input IR:
Will see the @callee parametric function be instantiated for each call-site:
Patch is 61.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/78544.diff 32 Files Affected:
diff --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td
index 844601f8f6837c4..68bae5d3f991dac 100644
--- a/mlir/include/mlir/IR/SymbolInterfaces.td
+++ b/mlir/include/mlir/IR/SymbolInterfaces.td
@@ -154,7 +154,7 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
"bool", "isDeclaration", (ins), [{}],
/*defaultImplementation=*/[{
// By default, assume that the operation defines a symbol.
- return false;
+ return false;
}]
>,
];
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index d81298bb4daf014..2f3e34e266e3fbc 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_interface(InferIntRangeInterface)
add_mlir_interface(InferTypeOpInterface)
add_mlir_interface(LoopLikeInterface)
add_mlir_interface(ParallelCombiningOpInterface)
+add_mlir_interface(ParametricSpecializationOpInterface)
add_mlir_interface(RuntimeVerifiableOpInterface)
add_mlir_interface(ShapedOpInterfaces)
add_mlir_interface(SideEffectInterfaces)
diff --git a/mlir/include/mlir/Interfaces/ParametricSpecializationOpInterface.h b/mlir/include/mlir/Interfaces/ParametricSpecializationOpInterface.h
new file mode 100644
index 000000000000000..88770e7239ac0f1
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/ParametricSpecializationOpInterface.h
@@ -0,0 +1,25 @@
+//===- ParametricSpecializationOpInterface.h - Parallel combining op interface
+//---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the operation interface for ops that parallel combining
+// operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_PARAMETRICSPECIALIZATIONOPINTERFACES_H_
+#define MLIR_INTERFACES_PARAMETRICSPECIALIZATIONOPINTERFACES_H_
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/SymbolTable.h"
+
+/// Include the generated interface declarations.
+#include "mlir/Interfaces/ParametricSpecializationOpInterface.h.inc"
+
+#endif // MLIR_INTERFACES_PARAMETRICSPECIALIZATIONOPINTERFACES_H_
diff --git a/mlir/include/mlir/Interfaces/ParametricSpecializationOpInterface.td b/mlir/include/mlir/Interfaces/ParametricSpecializationOpInterface.td
new file mode 100644
index 000000000000000..e3c12d6b4b60f98
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/ParametricSpecializationOpInterface.td
@@ -0,0 +1,46 @@
+//===-- ParametricSpecializationOpInterface.td -------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_PARAMETRICSPECIALIZATIONOPINTERFACES
+#define MLIR_INTERFACES_PARAMETRICSPECIALIZATIONOPINTERFACES
+
+include "mlir/IR/OpBase.td"
+
+def ParametricOpInterface : OpInterface<"ParametricOpInterface"> {
+ let cppNamespace = "::mlir";
+ let methods = [
+ InterfaceMethod<"",
+ "::mlir::LogicalResult", "specialize", (ins
+ "::mlir::DictionaryAttr":$params)>,
+ InterfaceMethod<"",
+ "::mlir::LogicalResult", "checkOperand", (ins
+ "::mlir::OpOperand &":$operand,
+ "::mlir::Type":$concreteType)>,
+ InterfaceMethod<"Only for symbol operation which will be cloned, mangle in-place.",
+ "::mlir::FailureOr<::mlir::StringAttr>", "getMangledName", (ins
+ "::mlir::DictionaryAttr":$metaArgs), "", [{
+ return failure();
+ }]
+>,
+ ];
+}
+
+def SpecializingOpInterface : OpInterface<"SpecializingOpInterface"> {
+ let cppNamespace = "::mlir";
+ let methods = [
+ InterfaceMethod<"",
+ "::mlir::SymbolRefAttr", "getTarget", (ins)>,
+ InterfaceMethod<"",
+ "::mlir::DictionaryAttr", "getMetaArgs", (ins)>,
+ InterfaceMethod<"",
+ "::mlir::LogicalResult", "setSpecializedTarget", (ins
+ "::mlir::SymbolOpInterface":$target)>,
+ ];
+}
+
+#endif // MLIR_INTERFACES_PARAMETRICSPECIALIZATIONOPINTERFACES
diff --git a/mlir/include/mlir/Transforms/ParametricSpecialization.h b/mlir/include/mlir/Transforms/ParametricSpecialization.h
new file mode 100644
index 000000000000000..1bbe3e2a557ef11
--- /dev/null
+++ b/mlir/include/mlir/Transforms/ParametricSpecialization.h
@@ -0,0 +1,11 @@
+//===- RemoveDeadValues.h - Specialize Meta Program -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Operation.h"
+
+namespace mlir {}
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index e7c76e70ed6b5d7..1998b66f168f36a 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -13,6 +13,7 @@ set(LLVM_OPTIONAL_SOURCES
LoopLikeInterface.cpp
MemorySlotInterfaces.cpp
ParallelCombiningOpInterface.cpp
+ ParametricSpecializationOpInterface.cpp
RuntimeVerifiableOpInterface.cpp
ShapedOpInterfaces.cpp
SideEffectInterfaces.cpp
@@ -80,6 +81,7 @@ add_mlir_library(MLIRLoopLikeInterface
add_mlir_interface_library(MemorySlotInterfaces)
add_mlir_interface_library(ParallelCombiningOpInterface)
+add_mlir_interface_library(ParametricSpecializationOpInterface)
add_mlir_interface_library(RuntimeVerifiableOpInterface)
add_mlir_interface_library(ShapedOpInterfaces)
add_mlir_interface_library(SideEffectInterfaces)
diff --git a/mlir/lib/Interfaces/ParametricSpecializationOpInterface.cpp b/mlir/lib/Interfaces/ParametricSpecializationOpInterface.cpp
new file mode 100644
index 000000000000000..80fc2caf0d12aca
--- /dev/null
+++ b/mlir/lib/Interfaces/ParametricSpecializationOpInterface.cpp
@@ -0,0 +1,13 @@
+//===- ParametricSpecializationOpInterface.cpp ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/ParametricSpecializationOpInterface.h"
+#include "mlir/Support/LogicalResult.h"
+
+/// Include the definitions of the interface.
+#include "mlir/Interfaces/ParametricSpecializationOpInterface.cpp.inc"
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index af51a4ab1157f15..8254f9d212c6035 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_library(MLIRTransforms
LoopInvariantCodeMotion.cpp
Mem2Reg.cpp
OpStats.cpp
+ ParametricSpecialization.cpp
PrintIR.cpp
RemoveDeadValues.cpp
SCCP.cpp
@@ -32,6 +33,7 @@ add_mlir_library(MLIRTransforms
MLIRFunctionInterfaces
MLIRLoopLikeInterface
MLIRMemorySlotInterfaces
+ MLIRParametricSpecializationOpInterface
MLIRPass
MLIRRuntimeVerifiableOpInterface
MLIRSideEffectInterfaces
diff --git a/mlir/lib/Transforms/ParametricSpecialization.cpp b/mlir/lib/Transforms/ParametricSpecialization.cpp
new file mode 100644
index 000000000000000..fcc3daacad447dd
--- /dev/null
+++ b/mlir/lib/Transforms/ParametricSpecialization.cpp
@@ -0,0 +1,13 @@
+//===- RemoveDeadValues.cpp - Specialize Meta Program ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/ParametricSpecialization.h"
+
+using namespace mlir;
+
+void specialize(Operation *op) {}
\ No newline at end of file
diff --git a/mlir/test/Parametric/ops.mlir b/mlir/test/Parametric/ops.mlir
new file mode 100644
index 000000000000000..ed8c87cd48ccee9
--- /dev/null
+++ b/mlir/test/Parametric/ops.mlir
@@ -0,0 +1,18 @@
+
+
+testparametric.func @callee(%arg0: !testparametric.param<"A"> ) attributes { metaParams = ["A", "B"]} {
+ %value = testparametric.add %arg0, %arg0 : (!testparametric.param<"A">, !testparametric.param<"A">) -> !testparametric.param<"A">
+ testparametric.print_attr #testparametric.param<"B">
+ return
+}
+
+func.func @caller() {
+ %cst0 = arith.constant 0 : i32
+ %cst1 = arith.constant 1. : f32
+ %cst2 = arith.constant 2. : f64
+ testparametric.call @callee(%cst0) meta = {"A" = i32, "B" = 32 : i64 } : (i32) -> ()
+ testparametric.call @callee(%cst0) meta = {"A" = i32, "B" = 32 : i64 } : (i32) -> ()
+ testparametric.call @callee(%cst1) meta = {"A" = f32, "B" = 64 : i64 } : (f32) -> ()
+ testparametric.call @callee(%cst2) meta = {"A" = f64, "B" = 128 : i64 } : (f64) -> ()
+ return
+}
diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index 30a17c201ff7635..8c1be74f15899ec 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -17,6 +17,7 @@ add_subdirectory(SPIRV)
add_subdirectory(Tensor)
add_subdirectory(Test)
add_subdirectory(TestDyn)
+add_subdirectory(TestParametric)
add_subdirectory(Tosa)
add_subdirectory(Transform)
add_subdirectory(Vector)
diff --git a/mlir/test/lib/Dialect/TestParametric/CMakeLists.txt b/mlir/test/lib/Dialect/TestParametric/CMakeLists.txt
new file mode 100644
index 000000000000000..dcc79f15993b7e9
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestParametric/CMakeLists.txt
@@ -0,0 +1,68 @@
+set(LLVM_OPTIONAL_SOURCES
+ TestParametricDialect.cpp
+)
+
+set(LLVM_TARGET_DEFINITIONS TestParametricInterfaces.td)
+mlir_tablegen(TestParametricAttrInterfaces.h.inc -gen-attr-interface-decls)
+mlir_tablegen(TestParametricAttrInterfaces.cpp.inc -gen-attr-interface-defs)
+mlir_tablegen(TestParametricTypeInterfaces.h.inc -gen-type-interface-decls)
+mlir_tablegen(TestParametricTypeInterfaces.cpp.inc -gen-type-interface-defs)
+mlir_tablegen(TestParametricOpInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(TestParametricOpInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRTestParametricInterfaceIncGen)
+
+set(LLVM_TARGET_DEFINITIONS TestParametricOps.td)
+mlir_tablegen(TestParametricAttrDefs.h.inc -gen-attrdef-decls)
+mlir_tablegen(TestParametricAttrDefs.cpp.inc -gen-attrdef-defs)
+add_public_tablegen_target(MLIRTestParametricAttrDefIncGen)
+
+set(LLVM_TARGET_DEFINITIONS TestParametricTypeDefs.td)
+mlir_tablegen(TestParametricTypeDefs.h.inc -gen-typedef-decls -typedefs-dialect=testparametric)
+mlir_tablegen(TestParametricTypeDefs.cpp.inc -gen-typedef-defs -typedefs-dialect=testparametric)
+add_public_tablegen_target(MLIRTestParametricTypeDefIncGen)
+
+set(LLVM_TARGET_DEFINITIONS TestParametricOps.td)
+mlir_tablegen(TestParametricOps.h.inc -gen-op-decls)
+mlir_tablegen(TestParametricOps.cpp.inc -gen-op-defs)
+mlir_tablegen(TestParametricOpsDialect.h.inc -gen-dialect-decls -dialect=testparametric)
+mlir_tablegen(TestParametricOpsDialect.cpp.inc -gen-dialect-defs -dialect=testparametric)
+add_public_tablegen_target(MLIRTestParametricOpsIncGen)
+
+# Exclude testparametrics from libMLIR.so
+add_mlir_library(MLIRTestParametricDialect
+ TestParametricAttributes.cpp
+ TestParametricDialect.cpp
+ TestParametricInterfaces.cpp
+ TestParametricTypes.cpp
+
+ EXCLUDE_FROM_LIBMLIR
+
+ DEPENDS
+ MLIRTestParametricAttrDefIncGen
+ MLIRTestParametricInterfaceIncGen
+ MLIRTestParametricTypeDefIncGen
+ MLIRTestParametricOpsIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRControlFlowInterfaces
+ MLIRDataLayoutInterfaces
+ MLIRDerivedAttributeOpInterface
+ MLIRDestinationStyleOpInterface
+ MLIRDialect
+ MLIRDLTIDialect
+ MLIRFuncDialect
+ MLIRFunctionInterfaces
+ MLIRFuncTransforms
+ MLIRIR
+ MLIRInferIntRangeInterface
+ MLIRInferTypeOpInterface
+ MLIRLinalgDialect
+ MLIRLinalgTransforms
+ MLIRLLVMDialect
+ MLIRPass
+ MLIRReduce
+ MLIRTensorDialect
+ MLIRTransformUtils
+ MLIRTransforms
+)
+
diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricAttrDefs.td b/mlir/test/lib/Dialect/TestParametric/TestParametricAttrDefs.td
new file mode 100644
index 000000000000000..c9133a99a344136
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestParametric/TestParametricAttrDefs.td
@@ -0,0 +1,38 @@
+//===-- TestAttrDefs.td - Test dialect attr definitions ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// TableGen data attribute definitions for Test dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TESTPARAMETRIC_ATTRDEFS
+#define TESTPARAMETRIC_ATTRDEFS
+
+// To get the test dialect definition.
+include "TestParametricDialect.td"
+include "mlir/Dialect/Utils/StructuredOpsUtils.td"
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/BuiltinAttributeInterfaces.td"
+include "mlir/IR/EnumAttr.td"
+include "mlir/IR/OpAsmInterface.td"
+
+// All of the attributes will extend this class.
+class TestParametric_Attr<string name, list<Trait> traits = []>
+ : AttrDef<TestParametric_Dialect, name, traits>;
+
+def TestParametric_ParamAttr : TestParametric_Attr<"Param"> {
+ let mnemonic = "param";
+ // List of type parameters.
+ let parameters = (
+ ins
+ "::mlir::StringAttr":$ref
+ );
+ let assemblyFormat = "`<` $ref `>`";
+}
+
+#endif // TESTPARAMETRIC_ATTRDEFS
diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricAttributes.cpp b/mlir/test/lib/Dialect/TestParametric/TestParametricAttributes.cpp
new file mode 100644
index 000000000000000..a5dde555b3e891e
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestParametric/TestParametricAttributes.cpp
@@ -0,0 +1,42 @@
+//===- TestAttributes.cpp - MLIR Test Dialect Attributes --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains attributes defined by the TestDialect for testing various
+// features of MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestParametricAttributes.h"
+#include "TestParametricDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/ExtensibleDialect.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/ADT/bit.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace testparametric;
+
+//===----------------------------------------------------------------------===//
+// TestParametricDialect
+//===----------------------------------------------------------------------===//
+
+#define GET_ATTRDEF_CLASSES
+#include "TestParametricAttrDefs.cpp.inc"
+
+void TestParametricDialect::registerAttributes() {
+ addAttributes<
+#define GET_ATTRDEF_LIST
+#include "TestParametricAttrDefs.cpp.inc"
+ >();
+}
diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricAttributes.h b/mlir/test/lib/Dialect/TestParametric/TestParametricAttributes.h
new file mode 100644
index 000000000000000..054fc0f598d0a4d
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestParametric/TestParametricAttributes.h
@@ -0,0 +1,34 @@
+//===- TestTypes.h - MLIR Test Dialect Types --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains types defined by the TestDialect for testing various
+// features of MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TESTPARAMETRICATTRIBUTES_H
+#define MLIR_TESTPARAMETRICATTRIBUTES_H
+
+#include <tuple>
+
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
+
+#include "TestParametricAttrInterfaces.h.inc"
+#include "TestParametricOpEnums.h.inc"
+#include "mlir/IR/DialectResourceBlobManager.h"
+
+namespace testparametric {} // namespace testparametric
+
+#define GET_ATTRDEF_CLASSES
+#include "TestParametricAttrDefs.h.inc"
+
+#endif // MLIR_TESTPARAMETRICATTRIBUTES_H
diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricDialect.cpp b/mlir/test/lib/Dialect/TestParametric/TestParametricDialect.cpp
new file mode 100644
index 000000000000000..693d93910e81886
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestParametric/TestParametricDialect.cpp
@@ -0,0 +1,297 @@
+//===- TestParametricDialect.cpp - MLIR Dialect for Testing
+//----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestParametricDialect.h"
+#include "TestParametricAttributes.h"
+#include "TestParametricInterfaces.h"
+#include "TestParametricTypes.h"
+#include "mlir/Bytecode/BytecodeImplementation.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/ExtensibleDialect.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/ODSSupport.h"
+#include "mlir/IR/OperationSupport.h"
+
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "mlir/Interfaces/FunctionImplementation.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/FoldUtils.h"
+#include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/STLFunctionalExtras.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "llvm/Support/Base64.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <cstdint>
+#include <numeric>
+#include <optional>
+
+// Include this before the using namespace lines below to
+// test that we don't have namespace dependencies.
+#include "TestParametricOpsDialect.cpp.inc"
+
+using namespace mlir;
+using namespace testparametric;
+
+void TestParametricDialect::initialize() {
+ registerAttributes();
+ registerTypes();
+ addOperations<
+#define GET_OP_LIST
+#include "TestParametricOps.cpp.inc"
+ >();
+}
+void testparametric::registerTestParametricDialect(DialectRegistry ®istry) {
+ registry.insert<TestParametricDialect>();
+}
+
+#include "TestParametricOpInterfaces.cpp.inc"
+#include "TestParametricTypeInterfaces.cpp.inc"
+
+#define GET_OP_CLASSES
+#include "TestParametricOps.cpp.inc"
+
+::mlir::ParseResult ParametricFuncOp::parse(mlir::OpAsmParser &parser,
+ mlir::Ope...
[truncated]
|
This implements the ability to define "meta program": that is a mechanism similar to C++ template. So as an example, this input IR: ``` testparametric.func @callee(%arg0: !testparametric.param<"A"> ) attributes { metaParams = ["A", "B"]} { %value = testparametric.add %arg0, %arg0 : (!testparametric.param<"A">, !testparametric.param<"A">) -> !testparametric.param<"A"> testparametric.print_attr #testparametric.param<"B"> return } func.func @caller() { %cst0 = arith.constant 0 : i32 %cst1 = arith.constant 1. : f32 %cst2 = arith.constant 2. : f64 testparametric.call @callee(%cst0) meta = {"A" = i32, "B" = 32 : i64 } : (i32) -> () testparametric.call @callee(%cst1) meta = {"A" = f32, "B" = 64 : i64 } : (f32) -> () testparametric.call @callee(%cst2) meta = {"A" = f64, "B" = 128 : i64 } : (f64) -> () return } ``` Will see the @callee parametric function be instantiated for each call-site: ``` func.func @caller() { %c0_i32 = arith.constant 0 : i32 %cst = arith.constant 1.000000e+00 : f32 %cst_0 = arith.constant 2.000000e+00 : f64 testparametric.call @callee$__mlir_instance__$A$i32$B$32(%c0_i32) meta = {} : (i32) -> () testparametric.call @callee$__mlir_instance__$A$f32$B$64(%cst) meta = {} : (f32) -> () testparametric.call @callee$__mlir_instance__$A$f64$B$128(%cst_0) meta = {} : (f64) -> () return } testparametric.func @callee$__mlir_instance__$A$f32$B$64(%arg0: f32) { %0 = add %arg0, %arg0 : (f32, f32) -> f32 print_attr 64 : i64 return } testparametric.func @callee$__mlir_instance__$A$f64$B$128(%arg0: f64) { %0 = add %arg0, %arg0 : (f64, f64) -> f64 print_attr 128 : i64 return } testparametric.func @callee$__mlir_instance__$A$i32$B$32(%arg0: i32) { %0 = add %arg0, %arg0 : (i32, i32) -> i32 print_attr 32 : i64 return } ```
69c09a6
to
db58552
Compare
This implements the ability to define "meta program": that is a mechanism similar to C++ template.
So as an example, this input IR:
Will see the @callee parametric function be instantiated for each call-site:
This is in early stages and is only a proof of concept right now. The test pass needs to be refactored into
utilities, with more verification and facilities for manipulating these.
More work will be needed to be able express arbitrary meta-expression as well, starting with arithmetic on numerical
parameters.